{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Fully Bayesian GPs - Sampling Hyperparamters with NUTS\n",
    "\n",
    "In this notebook, we'll demonstrate how to integrate GPyTorch and NUTS to sample GP hyperparameters and perform GP inference in a fully Bayesian way.\n",
    "\n",
    "The high level overview of sampling in GPyTorch is as follows:\n",
    "\n",
    "1. Define your model as normal, extending ExactGP and defining a forward method.\n",
    "2. For each parameter your model defines, you'll need to register a GPyTorch prior with that parameter, or some function of the parameter. If you use something other than a default closure (e.g., by specifying a parameter or transformed parameter name), you'll need to also specify a setting_closure: see the docs for `gpytorch.Module.register_prior`.\n",
    "3. Define a pyro model that has a sample site for each GP parameter. For your convenience, we define a `pyro_sample_from_prior` method on `gpytorch.Module` that returns a copy of the module where each parameter has been replaced by the result of a `pyro.sample` call.\n",
    "4. Run NUTS (or HMC etc) on the pyro model you just defined to generate samples. Note this can take quite a while or no time at all depending on the priors you've defined.\n",
    "5. Load the samples in to the model, converting the model from a simple GP to a batch GP (see our example notebook on simple batch GPs), where each GP in the batch corresponds to a different hyperparameter sample.\n",
    "6. Pass test data through the batch GP to get predictions for each hyperparameter sample."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import torch\n",
    "import gpytorch\n",
    "import pyro\n",
    "from pyro.infer.mcmc import NUTS, MCMC, HMC\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "%matplotlib inline\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training data is 11 points in [0,1] inclusive regularly spaced\n",
    "train_x = torch.linspace(0, 1, 4)\n",
    "# True function is sin(2*pi*x) with Gaussian noise\n",
    "train_y = torch.sin(train_x * (2 * math.pi)) + torch.randn(train_x.size()) * 0.2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [],
   "source": [
    "# We will use the simplest form of GP model, exact inference\n",
    "class ExactGPModel(gpytorch.models.ExactGP):\n",
    "    def __init__(self, train_x, train_y, likelihood):\n",
    "        super(ExactGPModel, self).__init__(train_x, train_y, likelihood)\n",
    "        self.mean_module = gpytorch.means.ConstantMean()\n",
    "        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())\n",
    "    \n",
    "    def forward(self, x):\n",
    "        mean_x = self.mean_module(x)\n",
    "        covar_x = self.covar_module(x)\n",
    "        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Running Sampling\n",
    "\n",
    "The next cell is the first piece of code that differs substantially from other work flows. In it, we create the model and likelihood as normal, and then register priors to each of the parameters of the model. Note that we directly can register priors to transformed parameters (e.g., \"lengthscale\") rather than raw ones (e.g., \"raw_lengthscale\"). This is useful, **however** you'll need to specify a prior whose support is fully contained in the domain of the parameter. For example, a lengthscale prior must have support only over the positive reals or a subset thereof."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Sample: 100%|██████████| 200/200 [00:12, 15.85it/s, step size=3.88e-01, acc. prob=0.971]\n"
     ]
    }
   ],
   "source": [
    "# this is for running the notebook in our testing framework\n",
    "import os\n",
    "smoke_test = ('CI' in os.environ)\n",
    "num_samples = 2 if smoke_test else 100\n",
    "warmup_steps = 2 if smoke_test else 100\n",
    "\n",
    "\n",
    "from gpytorch.priors import LogNormalPrior, NormalPrior, UniformPrior\n",
    "# Use a positive constraint instead of usual GreaterThan(1e-4) so that LogNormal has support over full range.\n",
    "likelihood = gpytorch.likelihoods.GaussianLikelihood(noise_constraint=gpytorch.constraints.Positive())\n",
    "model = ExactGPModel(train_x, train_y, likelihood)\n",
    "\n",
    "model.mean_module.register_prior(\"mean_prior\", UniformPrior(-1, 1), \"constant\")\n",
    "model.covar_module.base_kernel.register_prior(\"lengthscale_prior\", UniformPrior(0.01, 0.5), \"lengthscale\")\n",
    "model.covar_module.register_prior(\"outputscale_prior\", UniformPrior(1, 2), \"outputscale\")\n",
    "likelihood.register_prior(\"noise_prior\", UniformPrior(0.01, 0.5), \"noise\")\n",
    "\n",
    "mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)\n",
    "\n",
    "def pyro_model(x, y):\n",
    "    with gpytorch.settings.fast_computations(False, False, False):\n",
    "        sampled_model = model.pyro_sample_from_prior()\n",
    "        output = sampled_model.likelihood(sampled_model(x))\n",
    "        pyro.sample(\"obs\", output, obs=y)\n",
    "    return y\n",
    "\n",
    "nuts_kernel = NUTS(pyro_model)\n",
    "mcmc_run = MCMC(nuts_kernel, num_samples=num_samples, warmup_steps=warmup_steps, disable_progbar=smoke_test)\n",
    "mcmc_run.run(train_x, train_y)"
   ]
  },
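  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The support requirement described above can be sanity-checked empirically by drawing from a candidate prior and confirming that every draw lies in the parameter's domain. The sketch below is only an illustration: it uses plain `torch.distributions` objects as stand-ins for the GPyTorch priors (an assumption made to keep the example self-contained), and the helper `all_draws_positive` is a hypothetical name, not part of GPyTorch.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "from torch.distributions import LogNormal, Normal, Uniform\n",
    "\n",
    "torch.manual_seed(0)\n",
    "\n",
    "\n",
    "def all_draws_positive(dist, num_draws=1000):\n",
    "    # Hypothetical helper: check empirically that draws stay in the positive reals.\n",
    "    draws = dist.sample((num_draws,))\n",
    "    return bool((draws > 0).all())\n",
    "\n",
    "\n",
    "# Candidate priors for a positive parameter such as a lengthscale.\n",
    "uniform_ok = all_draws_positive(Uniform(0.01, 0.5))     # draws lie between 0.01 and 0.5\n",
    "lognormal_ok = all_draws_positive(LogNormal(0.0, 1.0))  # supported on the positive reals\n",
    "normal_ok = all_draws_positive(Normal(0.0, 1.0))        # supported on the whole real line\n",
    "\n",
    "print(uniform_ok, lognormal_ok, normal_ok)\n",
    "```\n",
    "\n",
    "A check like this cannot prove that a prior's support is contained in the domain; it only catches obvious mismatches, such as an unconstrained Normal prior placed on a positive parameter."
   ]
  },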
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Loading Samples\n",
    "\n",
    "In the next cell, we load the samples generated by NUTS in to the model. This converts `model` from a single GP to a batch of `num_samples` GPs, in this case 100."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.pyro_load_from_samples(mcmc_run.get_samples())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.eval()\n",
    "test_x = torch.linspace(0, 1, 101).unsqueeze(-1)\n",
    "test_y = torch.sin(test_x * (2 * math.pi))\n",
    "expanded_test_x = test_x.unsqueeze(0).repeat(num_samples, 1, 1)\n",
    "output = model(expanded_test_x)"
   ]
  },
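  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The shape bookkeeping in the cell above can be checked in isolation. The sketch below uses only `torch`; `num_hyper_samples` and `test_inputs` are illustrative stand-ins for `num_samples` and `test_x`, and the comparison with `expand` is an aside rather than something the model code requires.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "num_hyper_samples = 100  # stand-in for num_samples\n",
    "test_inputs = torch.linspace(0, 1, 101).unsqueeze(-1)  # shape (101, 1): 101 points, 1 feature\n",
    "\n",
    "# One copy of the test inputs per hyperparameter sample: shape (100, 101, 1).\n",
    "repeated = test_inputs.unsqueeze(0).repeat(num_hyper_samples, 1, 1)\n",
    "\n",
    "# expand() produces the same shape as a view, without copying the data.\n",
    "expanded = test_inputs.unsqueeze(0).expand(num_hyper_samples, -1, -1)\n",
    "\n",
    "print(repeated.shape, expanded.shape)\n",
    "print(torch.equal(repeated, expanded))\n",
    "```\n",
    "\n",
    "The leading dimension lines up with the batch of GPs, so each hyperparameter sample sees the same 101 test points. The notebook uses `repeat`, which materializes the copies and is the conservative choice."
   ]
  },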
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plot Mean Functions\n",
    "\n",
    "In the next cell, we plot the first 25 mean functions on the samep lot. This particular example has a fairly large amount of data for only 1 dimension, so the hyperparameter posterior is quite tight and there is relatively little variance."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQAAAADDCAYAAABtec/IAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAyCElEQVR4nO2de3hU1bnwf2tPriSZTBIC4Z6EOwgkMIigoEJsUUAFKVXa4/EIWlvb09Ny7MHKV+ypF7Tl9PN2ekQ5z/NVpYIRqVYEhagocptwk5sQEu4JAZIJBHKb2ev7482ES7gEEpKBWb/n4SGZvbPnnT17vetd720prTUGgyE0sVpaAIPB0HIYBWAwhDBGARgMIYxRAAZDCGMUgMEQwhgFYDCEMGGNvYDb7c6q/fEOj8fzH429nsFgaD4aZQG43e6BwECPx7MMGOh2u9ObRiyDwdAcqKZIBHK73S7gBY/H85NGX8xgMDQbjV4C1OIGdl/shOnTp5uUQ4OhhZg1a5Y63+tNogA8Hs8yt9v9A7fbPdHj8WRf6Lzf//73l7xWcXExbdq0aQqxrhrBLmOwywfBL2OwywcNl3HmzJkXPNZYH8ALbrf70dpfvUBiY65nMBial8aGAV8H8msjAS6PxzOnCWQyGAzNRKOWAB6PJx/Ir/11WePFMVyv+Hw+ioqKqKqqQmuNbdscP368pcW6IMEuH9SXUSlFZGQkKSkphIU1bGg3lRPQYLgoRUVFxMTE0KFDB5RS1NTUEB4e3tJiXZBglw/qy6i1xuv1UlRURMeOHRt0DZMJaGgWqqqqcLlcKHVeZ7ShCVBK4XK5qKqqavDfGAVgaBa01mbwNwNKKS4nt8coAEPQUVhYSFZWFkVFRS0tynWPUQCGoOP555/nm2++4bnnnrvia2zYsIGFCxeyfPly3nzzTfLzxVe9cOFCnnrqqaYS9YJ4vV7uuuuu88rVu3dvli9fzvLly5k9ezZer/eqy3MhjAIwBA0ul4uoqCjmzJmDbdvMmTOHqKgoXC7XZV3H6/Xy5ptvMmHCBEaNGsXUqVOZMWMGACNHjrwKktfH5XKRlpZW7/XMzEzS0tIYNWoUo0aNYtq0aUyePPm81/B6vcyePfuqymmiAIagYfv27UyfPp0PP/yQiooKoqOjueeee5g1a9ZlXSc7O5vMzMyzXktISGDDhg2kpaWxYcMGNmzYQE5ODlOmTCE3N5fExERycnIYP348OTk5xMfHM2jQIDZu3Eh2djaDBg0iLS2N7Oxs5s2bx+OPP860adMAyMnJITExkYyMDBITE+vOLygoaJC8LpcLr9dLSUkJOTk5lJWV1cmVm5vLhg0biI+PP+vY5SrFC2EUgCFoaNeuHU6nk6qqKqKioqiqqsLpdJKSktKk75OZmVmnIObOnUtJSQkg1sGMGTOYMmUKTqeTuXPn8sQTT7B8+fK6wZ6dLZnuEyZMID09ncmTJzNlyhQSExOZO3cuXq+XadOmkZ6efl4L4EKUlJSQnp5Oeno6y5cvZ+7cuUyZMoWcnJw6Wc88FpCnsRgFYAgqiouLeeSRR5gyZQpz5869IkfgxIkT+dnPfsbUqVPrXisoKCAzM/O86+0pU6YAMGPGDKqrqxk0aBAxMTEMHjwYr9d71mw7cuRIZs+eXfc3AIMGDcLlcpGZmcnjjz9OYqJkxJeWljZIXq/XWze4CwoKGDRo0FnH8/PzKSgoOO+xxmIUgCGomD9/ft3PL7300hVdw+Vy8cQTT/Dmm2+SlpZGQUEBr7766lnnBJYA06ZNY/bs2WRkZDBx4kQyMjLIzs5mwIABhIWFUVJSQkFBQZ0imDhxIk899VSdUnjmmWfqTH6AadOmnbUE2LBhw1nLkcBgXr58OcBZspWVlZGYmEhBQQH5+fmUlJTg9XopKCiodyw/P59OnTpd0f05kybpB9AQpk+frk01YPMQjPLl5eXRrVu3ut+DPdMu2OWDC8t47r2eOXPm
BcuBTRTAYAhhjAIwGEIYowAMhhDGKACDIYQxCsBgCGGMAjAYQhijAAzXHV6vt64QaPbs2XWFQI1h+fLlPP744xd9z/MV/4DE/ocOHcqGDRsu+lpLYBSA4bojOzubkSNHMmrUKDIyMprkmqNGjbro8QsV/4Ck8AYSkwKUlZWRlpZWr2ahuWlUJmDthiDptf8Gm63BDA2huhry8hTh4VfeIKRLF01ExPmPDRo0iDvvvJOJEyfWFc7k5+fXK7QJ5PoHFEYg1bagoIAFCxbwyCOPkJOTwxNPPFGX+Re4zuUW/8THx1/w2LnXBOrJOnv2bJ599lmys7N59tlnWb58OU6nkxUrVjSqLqCxFsAkwB3YC+CMFuEGQ4uRmZnJqlWryMjIYPLkyXW59lOnTiUjI4O5c+fWzeiB+oBBgwYxcuRIcnJyGDlyJKmpqYwaNYqRI0cyd+7cumvPmDGDtLQ00tLSmDt3Lk899RQjR46sK/O9GBMmTGDhwoX10oPPveb5ZA3UGgRqGXJycnj//fcbbeE0tivwmW3A04HPGiWNISSIiIDu3TXh4VcnDX3hwoV1vQASExPJzc0FqFdMk5CQUPezy+Wqqwq8FFda/DNq1CgmT57ME088cdFrnq8o6ExZQQqYfD4fTz/9dN3fXglNUgxUuyloSe0moRekuLj4ktdqye4oDSXYZQxG+Wzbpqampu53v99/1d7ryJEjvPfee8THx1NQUMDDDz/MokWLiI+PJy8vj7y8PBYtWkR+fj5HjhwhPz+fdevWsWfPHtatW8eDDz7Ixo0bWbduHbm5ufzqV79i3bp15Ofn89JLLzF//nwGDhwIwC9/+cu63wPXOXdWLigo4I033qB///5kZGQQExNTd72dO3fy9NNPn3XNkpKSS8r6+eef079/f8aPH09MTMxZ99a27QaNNWi6zUF/4/F4XrzYOaYYqPkIRvmupWIgr9fLrFmzLrsRSXMTFMVAtfsBvlj788DGXs9gaGlyc3PZtGlTS4vRLDQ2CpAFvOB2u5+sfclEAQzXPKNGjWLEiBEtLUaz0Fgn4DKgaxPJYriOUUpRWVlJZGSk2R/gKqG1pqqq6rLur+kIZGgWkpOTOXz4MD6fr25vQMsK3jy0YJcP6suolCIsLIzk5OQGX8MoAEOz4HQ6cTqddb8Ho6PyTIJdPmgaGYNbxRkMhquKUQAGQwhjFIDBEMIYBWAwhDBGARgMIYxRAAZDCGMUgMEQwhgFYDCEMEYBGAwhjFEABkMIYxSAwRDCGAVgMIQwRgEYDCGMUQAGQwhjFIDBEMIYBWAwhDBGARgMIUyTdAV2u91mQxCD4Rqk0QogsC2YwWC49jBLAIMhhGnWpqBma7DmIdjlg+CXMdjlg6aRsVkVQEM7mAZ7N1YIfhmDXT4IfhmDXT5ovIxmCWAwhDBNEQXIkv/cE5tAHoPB0Iw0eglQuz1YwiVPNBgMQYdZAhgMIYxRAAZDCGMUQBBTWFhIVlYWRUVFLS2K4TrFKIAg5vnnn+ebb77hueeea2lRDNcpZnfgFqCmBg4ehG3bFGvWKHbtsjh+HCorwe+H1atj0bqy7vw5c+YwZ84cwsOj2LzZS+fO4HC04AcwXDcYBXCVOXQINm+2KC2VAb9xIxQXK8LDITYWunTRDB5sExam2bZNsWmTRa9euzh06DccP74IrStQKpqYmPEkJ7/I1KlhOJ2QkqJp107Tqxd07arJyNBGKRguG6MAmpiKCli+XPHtt4qDBxVlZVBUpNAa2rTRdOumCQ9XHD4Mp06JUli3TnH4sCI+XtO5M7Ru3Y7ycidlZVUoFYXWVcTGxhEdncKuXRAfDz6fIi9PsW4dREZCcrJm+HCN0wndu1tcA0lshiDAKIAmoLwc5syxOHZM4fOBUprDhxVHjii8XoiOhmPHZOY/cECRkWEzZAgUFCjWrlVYFjz4oE10NCQkaNq3h//7f4tIT3+Ef/qnqSxe/CYHDhQxbZpNaSksXarYuFHRsSN0
6GATFaXYuVNRUKBITdUo5SQpyUFqqmbyZJtWrVr6DhmCFaMArpCqKpg3T7Fzp6zffT747juZ2W1b0aoVWBYoJTP2yJGauDj529JSxYYNcn7nzpqEBMWhQwqA48fh0CGL+PhsUlLggw8Ux469RmqqZutWzZAhmp/+VFNcrHjnHYuyMsWxY4p+/TQpKbKMKChw4HRqoqPhj390UFUF3btrWrcGp1PTq5emTRuRzRDaGAXQAE6cgP37VZ3jLicnkaNHZS1eVSWmvNMJSUmatm0VJ0/KQEtJgXbtNLYN5eWKVatg714L0JSXQ8+eEBWlcDhksAKUlSmGDLE5dgwKCxU1NfL61q0yy3/zjUYpRc+emsxMmz17FFFRGqdTs2WLIiEBpkwpZ8uWBN5+2yI5WfPzn9toDQcOiGLau1eWHSBKQGt5j8hITZcu4peIjGz++2xofowCqMW2YdcuWVcHBkSAqCjNunWKTz+1iI6Gvn19DB1qc+yY4tgxyMuzOHZMU1ioaN8e2rfXHD2qKC/XrFtn1SoJRc+eNuPG+dm71+IHPxCTv7paFMjhw+DxWKSlaSorZV3frp0mMRHcbk1ams2HH1rs2KGoroacHIXLJT4Gvx9SUhSTJtnk5SlWrowkKkrx29/6+e47eOopB0lJ8Lvf+bFt2L5dkZkp/ogzqayEffsUK1YoqqrONg86d9b06aMJC/EnRmvweuHQIcXx43DypLxmWRATIxNBmzbyvVnXQJA9pL/O/HzYssVCKfmyunfX3HmnXffF7dsHv/2tg337LLp21cyd62PZMouPPgqjvFxRVGShtSY+Xtbt3bppjh1ThIVBWpqmRw9NbKyma1dx0Hk8io8+svjf//UREwN5eYodOxRKQW6uokcPTUKCOPJiY2V2PnIEVqyAt98Oo7JSFEaHDvDzn9uUlYlfYfNmxc6dMHOmRadO0KuXIipKlMTx44rXXvOxerXi8ccddO6s+eMf/Rw4oJg/32LkSJvkZPm8UVHQo4emRw+A08pBa9i7F5YutfD7Ra4hQ+yQcDSWlMDatRY1NaeXTC6XRGDatpVBr5RMICdPyhJu1y5FSUn9iaRzZ03v3prw8Ob/HBci5BTA8ePw5ZfyIKena8aNs+uthd9+W/Heew5OnYIJE2zuv1+zfbvi178OY88eOH48AqdTkZws6+m2bWXwl5fD0KGa9HTNyZOy1r/rLpuICBngS5ZYjBhh8847Fvv3K9q313TsCIcPK55+2l878C6Ejc8He/Yo/vxniy+/tCgqEt9Dv36a8eM1ixZZFBeDxxNJWJgiLEyUyltvWSil+PhjH2+8YTFhQhi33qr53e/8fPWV4tQpkfNCs7tSkJoKqam2SGLD6tWKlSsVrVtrbrlFX1f+BJ8PvvgiiogIi8REze232w1aErVqBcnJEpY9U4GCKNF9+2DZMgufT36PidHceONp31BLEDIKQGZSC6dTM3q0XU8LnzoFzz9vsXmzhdNp07+/pn9/TUWF5tVXHeTny3ky2KtxuSLJytK0bauprpY1flaWzalTitxcxYgRNu3by0PwxReKl1920KmTZu9exbBhmhtuEMWzeze0bq0uMfiFsDBRNLNn+1m82GLCBJsjR+D//T+LhQstKipk6dCnTxWW5aCgQNXmFiiSkmD/fgd33qn59FMf//EfDr7//XAefdTHuHE2Cxda9Ogh+QSXwrJg2DB5yA8fhoULLdq2FUVwLePziZXj88HgwVWkpdlNdm2lqPWvnL7miROwdq0oYIC4OFEIzRm1UfpcO+UqMX36dP373//+kucVFxc3aSeWqipYvNjC5dLcdlv9mWrvXnjlFRksTqcmORlGjLCprISXX3awb58iJUVmeadTvOc9e3r5wQ/iWLNGBt3IkTa2DZ99JkuFQYPknn79teLtty0KCxVTpvjp3RsKCqCm5rQQX32lGD5co/VpEzPwc+CrCawxO3WSZUVUFGzaJEuNvn1Pf39HjsBf/2rx3/8N
rVtbjBljc/KkWDxaywMeFgZJSfD66z7KyuC3vw0jPl7zr/8q4cItWxR3323XOSUbyv79sHKlxbBhNp07X/r8pv6eG8vatYr9+xV33in3oSXk83ph3TpFZaU8CGFh0KuXTWpq/YhNYWEh999/P/PnzyclJeWi1505cyazZs06r412XVsA69bJlzpmTH0TbuNGxbx5Fvn5kJgoZtu4cX5ychy8/LKDnTsVvXppbr5ZNHZyMowebZOVpcnO1nz5pUVWlo3DIbNGTIxm4kSZ1f/6V8WXX1o4HPIlDh9uEx6usG1RQhERMmg/+0zx7/9u07r1pT+LbYsJ+dVXpx+QlSsVDz1k07OnKLbkZJg2zWbs2CP8z/+ksG6dwudT3HijZAl++aU4ECsrYdy4MHr21Lz8so8PP7R48kkHd99tM26czeLFFt26aQYMaPjk0KkT3H+/zRdfKHbsgO9979qwBqqqYNEii0GDNDfe2HQz/pXgcsEdd5xePtTUwI4disWL1VmTglLw+uuzWLt2Lc899xwvv/zyFb/ndWkBVFbKl5qZqenZ8+zPt26d4r33LPLyoH176NgRoqM1hw7Jg7tzp0W7dppRo2zWrLHo1EkG9q23apYvV5SXKwYMOEz79sl8+qk4EL//fZvCQnjuOQdHjihuvNFm1CjN5s2Ktm1h7Nj6D1bAuTR69JU/dGVlMH++RWACiI7W3Hyzpry8mOPH27JihTgDd+xQ5OeLgjh0SEKKAYVx8iTceqtm+HCbv/7VQVycZtw4Xbdcufde+7JTjIuKYPlyiXRERJz/nGCwAPbtE6vlvvvqyxkM8gXQGnbvVmzfrpg82UlNTWW9c6Kioi7YJDSkLIDduyWcNmHC2V/q+vUy4+/fDykpcPvtmm+/tThwQOPzSUZeYiL8+tc1LFoUxp49ip/+1M/3vqf5+mvFBx+Ix9zh0HzwQRSJiaIo9u9XPPaYA59P8e//7qN/f/EnLFhgERkJY8acf4B/+qnFD3/YuBknPl6iDRkZum4wf/21orAwmvbtxcl58KDigQdsamrg3Xct+vXTdO8Oa9dKJqLDAZ9/rjh2THwAXq9kNY4da9Onj+bddy1uu82mQ4eGy5WSAj/4gc2CBaLgGmLhNDebNimKi+GBB5pn1rdtKCyUe+71nj0WzzTvtZaZf8MGKCyUJabW4h8oLVX06rWLvLzfUFHxd+AU0dHR3HPPPcyaNeuK5LquFMCKFTKznTmwtm9XvPGGeMfbt5eY+qpVFkeOwKBBNkuWWNg2/PjHNitXWnz5ZRgPP2xz5502W7YoFi60GD7cpn17+PxzC69XvrDqanj/fYVtK5580k/37vJ+Ph+8956FZcGkSfUjDAAbNigyMprGc56VpZk/3+L++21iYsT0Li6uoFWrOFassPj2W0VhoZjnzz3nZ+VKxQcfKKZOtTl0SKyhgwdh2TJFejpMnizLpbffdjBkiJ8hQ2DVKslvEMdfw4iIgB/9yOb99y2GDLHp1Knxn7WpWLlSERERMLebHtsWK2vPntPJVkpJXkffvhqXS9eFVvPyFBUVij174OhROT88XPxQaWk+/vd/Ldassdi8WZZvWrcjOtpJRUUl4eGRVFVV4XQ6L+kHuBDXhQKwbfFEZ2badO0qr+3dC6+95qCoCNLSIDxc4/FYtG2rmDDB5vPPYckSi759bcrLxQIYM8Zm9Gibo0cVH38sx2JjJXx34oSscwcP1mRnh1FZqXj0UX/d+4Fo6r/9zSI2Vtb95zN/fT6JE0+a1DQzj1IweLDN2rWy1g8QGwt33WVz553wpz9Z/PnPDoYOlbDnsGF+srPFP/KXv/hZvlxyAnbsgN/9zuKhh2wee8zP3/4m9+WHP/RTUiJ1DPfd1/AlgVIwcaLNxx9b1NTYpKc3yUduFJ9/LhGR/v2bdvCXlcE331hUV4vDtm9fm7FjZaDv2QO7dlkcPCgFYkpJ/YjXC61bS+Tm3nttEhLgk08UmzdbvPqqg8JC8SHF
x0PnznJ+TQ2UlR2mQ4dHmTFjAhs3vt+ohjGNVgC13YC9QLrH45nT2OtdLuXl8P77FuPH2zidUFoK//VfDnbvhgEDxAReulTRqxfMnu3n3XcVCxZYJCRIgs7x44q0NPkCIiIkVFhTIwU8H31kkZoqCUIDBmhee81i1SqL++47yY031neTv/eeRUqKJAVdaPn48ccWd93VtGZn165i0g4YUD+FVyn42c9s/vEPKUlevFgiArfdZnPvvX5ee82itFSRne3jrbcsPvzQ4i9/Eb/CzJk227Zp5s51kJkp9+uPf3Tw6KN+EhMbLt+YMTaffGKhlE1aWpN+9Mvi888VKSnQu3fTDH6/Xxyrx49LBGnIEJvvvpPajK1bLbZtC4T/JJfg0CHYtEn8Rh06aCZMkCXXggUWq1Y5iIqStPFlyxSlpeKgjIqSZYPPB/HxkkTUunU2Tid061bEv/zLiEZ9hkYpgEArcI/Hs8ztdj/qdruzarsENwsFBbBuncU//ZMkycyebbF6tUW/fjYHD1rs3q0YOtTm1Vf9vPyyg//8TwetWmkSEmSAVlWJmd6xoyY720FMjMRgIyPl32OPyUB95RWLjz+2eOQRP2lpUFxcfwD//e8WnTtLGu+Zobkz2b1b0kRjY5v+XowZY/PRRxYTJ9aXLSZGMvd27lQUFip+9COblSulWvHuu23i4uDPf3bQsaPmT3/yk5ureP11i8cec3DrrTbTp9vMmWNRUKBwu21++csw/u3f/HXhzoZw551iCTgcDQsTNjUrVijatGmawV9aKsvBw4chIUGsrZMnpcCrf3/N0KGn3yM/X/pB7N6t6NJFc/fdNtXVsGCB4pNPHDgcmpgYzaZNslw7cQKcTkWrVpp27cRiTE8X66ldO01RkRyPi6Muz6QxNNYCGAzMr/05HxgIXFABNOXWYB5PBFVViltvreJ//qcVH38chcMBRUVh7N8P9957gttvr+T992N46KFwYmOrsSyLdu1qKC4Oo1OnKpxOm4ULI2nd2k/Hjn7CwqBzZx833FBDYaHFH/4Qi9bw4IPltGtn136G+jIuWxZF+/Y+tm0LZ+zYCs73MUWzt2L8+FPnPd4UtG4dzhdfQPv23nrHWrUCrSNISYE33ghjwoRT9O4NmzaFs39/GGPG1HDihOKvf41h6NBKfvELi1On4IUXnKxercjKOkV8vGbZsij69avgmWcUCQmaZ54pbXB9wODBsHhxNG53FRER9WW8Wng8EURHa5KTaxp878/3HObnh/HuuzFER2t6964mI8NHerqvnhNv61aL3NxItIZOnXwMHVqDUpCf7+DXv47h4MEwEhN97NoVzrFjDo4etaip0cTF2WRlVZOa6qOiwsLrVVRUiD9Jaz8+n59f/OIkLpeuk7G4uHFKoLEKwHXO70kXO7kptgbz++GDDyxuuEGTlwf33+/A75e1Ups2mp/8RMpet2xx8sorTmJjJdyVlCRpwFpHMmiQ5uDBKOLjbUaNku48t94qM+HatZJq63TCjBl+XC6AqAvKuHSp+B7WrrX4l3+xUer8eZ0ffihef5frKkz/dTJBdrZFerqPNm2S6x3//vdlCXD77TZffulk0iSbO+6QY1u2iHXw859rNm2KIi9PZvu5c2HePM0XX8TgckmS06FD0SgFbdrYTJ3anvHjbR54wG6QZfPQQzB/fjyZmfq8MjY169crWrfmLP9IQ2nTpg3Hj0vExuNRtGunefHFgG/n7GeislISrioqpDnLj3+sawcuvPOOYtEiB5YlRWIRERKODZSRt2mjGD3az803Swryrl1w7JiiQwfNHXdofvQjm7i4wFBtVU/GxtBYBeAFLmM12Dj27IEFCxxERGheeslBdTW1BTQSDquoUBQXK1avFoXQrp2mslJKavfuVfTrZxMVJSbxTTdJHr7bLcuHd96RSrvu3W3+8z/9DSrY+OQTi+7dbb75xuKBB87v8Qfx+nfpomuVydXlnnts3ngjmp/97PzH77rL
Zt48i6FDbT74QHwnADfcoLnhBql5iIqS8NiSJYoTJxRZWZr77vPxzDMONm+W4+Hh4sxMSoJ//EOxaZODW27RDcoEnDTJ5i9/iaZDB/kurhZbtyrKy2HEiMsf/Fu2hFNaarFzJ/TpA88+66/n/NRaEsqkJFsmkUAab0EB/PnPFtu3W1RWSrJZTY0iKUkKxrQW30BSEkya5Oe99yz+8AcH7dtLQdj/+T8+oqLEUbtihXXWewa44YbGh5EaqwDWcdoKSAc+a+T1zsK2pWnG9u2K9etlIG/aJJlwAwZIDf2JE7IG8/kkzlpTI1VXRUWK775TREQo+vaVdb7PJ6HAW26xSUqCrVth5kyxIMaNs3nooYY75z78UHwNa9da3H33hRNejh6FgwfVeZOBrgbh4TB8eCWLFzsv6GycNEkKkm6/3WbRIot77z19Xu/eUrG2fbuiZ09ITLT59FOLqirFs8/6yc6WtarLBUeOKIqKFGlpmoICCVPl5Vn06SPK96abzh/qVArGjz9FdraT++9vWKHN5ZKXpzh06PJCfVpLpuXRo2J6R0TA44+Lc/lMSkrgq6/EmTpggM348fIeBw/KpLBggfRuaN8eWreW6s6iInHouVwyUfXrJ8Vjx47BjBlhdOliM2mSVGnedJOmpESRnq4ZNerC1YONNf+hkQrA4/Fku93u39TuD+hqrAOwpkbW006n3NyKCllzHz6s2LdPceAAtG2riYiQWuz4eJnBN26UzjidO2u++06RkyO5/wFnTIcOkhQzbJimuBhef92ivFzOf/JJ/2Xlvds2LFrUinHjJFPw1lvtC87sJ05IqPFHP2reFNO2bW201qxZoxgypP5DEhYG990nTsPhwyVWP2HC2RZMQBFs26YYNEijteYf/5Cl0b/9m5833nDQpYukNefmWkRESKKK3y9Osl69pBeiZUlMO+GczeMcDsk5mDfPYvLk+sVZjWHPHti5U11WtOWbb6QrU0aG5Ed07uxn+PDTf6+1nFNcLL6PsWMlHFpSAv/4h8W+fdRWXUJqqhRG+f3S1CUyUqyrvXslk9Tvh08/VbWefJg920d6umbFCotf/KJh1mdT0RR7A75Y+2Ojvf+2LWv8Y8ekfPbECTh+XDK2oqKgZ09JoujYUVpiHT4sZlVKiiSrfPutrPeysmz8fum4c8stfrxeSZldssQiLg4efdR/RdlpJ05IvkFWVgW5uU4yM20ulH9x6pT4Kn784wsvDa4m/fuLAli9WmaUc4mNlSKmb76xuOUWm3ffFR/FuU0s+vSRRiBbtyosS3PkCHz4oaM2rCWKecwYm/x82LHDYs0aaW4aFaXZsUMGlHyXUj49bJiue4/wcElQmjdPlGRTNBvZu1d6PDTU4tqxQ56bm26SZ2bHDotJk2yOHvUBYlWuWSMD+6abbG6+WSI9OTnSg3HVqtMJP3fdZbNtm0ViooSnT51S3HyzNGn55BMJL2sty9UnntBMmiR+k40bJfbfEs9KUCUC1dTAxo0RdOggobSvv1YopZk8WXP8uLTlqqhQeDwSRw0Lg9WrpeNOt27SIrtHD01qqmbXLsWGDYr16x106KD59a/tRq038/Lkej/+sc28eVEMH27Tpcv5zz1wQEzEyZPrD6jmZMgQzfr1ipwcxciR9ZVA27aQmWmTmyu5CW+9ZV3QJO/bV7LYtmxRtG5ts22bzG79+tns2CGtx+67z89XX0mPxAULHPTpY9Oqlex9kJEh5dUffyyZl7Gx4YwcKeHWSZNs3n5b3juqvr+1wezeDdu3N2zwl5WJA7dnT83IkTZLl4oi7NxZU1MDK1ZEEhYmZc733CMt1VauVLz1ltSRREdLfn7btppevWxKSy3KyyUD8sgR6f60fj28+aZ83ptv1iQna/r1k54R3brJkvSDD6QL1LhxLVOIFFQK4MSJQr744gH69n0Pr7cdt91ms3+/mPSxsZqICNGsBw5InbvDIQPftkFrRViYJP54vWLiduvWeJlsWzzn
iYlSFJSdbdG/fzWpqec/f8UKRWVl8+WYX4qBAzV79sC8eRb33FNfCaamgs8n6dGTJ0v+flaWTdu2579ewFm4bZuid2+bL79U2LYMqLIyi2HDNEePymDct08shK5dNYWF8Pvfh5GRYXP77RqfT1qcgXjNf/hDee+7777wkupibNqkOHr0/IVXZ6K1JNr4fIqJE0X+oiJRPps2ia8pPBwyM6tJT7dZtkzxu9858HrFYunZUxMXJ1ZpZKQ4mdPTbYqLJd7fubNYS1u2iCUwdarNbbdpsrMtxoyRzlAgRWl79oj11JJdm4NKATz99PNs3bqWqqpnyMx8jY8/turMzpMnFadOSZplQgLExUklX5s2mrFjxbS8kCPuStm+Xdpvjx4t9fFvvWUxbpxNTY3/rPO0Fk9/Xp5i2DCbjh2bVo7GkpoKnTpJNp5twx13nF3v362bRinNp5+KEvjkE5n5LpboE1gaDByo8HgUixdbFBaKAuzRQ3PLLTbffmtx7Jj4BJYskerMffsUc+YovF4n994r/gHLksSa+Hj4058c3HSTZsyYhpnDWkuYLiFBHGYXIz9fzPmsLBvblkKnTp3EuvzoI4sbbrDp1UscgcuWOQGpjszMlI7LS5ZYfPCBOAfj46Ut2MaNin37HDidEtorLxez3+2WmpT/+i+LL75Q/OEP8szk5ChKSxUDB9oMHtzyJdNBUQ7scrmorKxf4ghRJCWVk5AA6enQs6fN4MGyhnK7r14bqvx8yM2V6rgBA3TdgzNxojiriouLadWqDbm5irIymQEHDGjZNNczuVgpqzQUlWVTeLg0nEhLE8/8vn2nP+e2bdKv8O67G+agO3hQBvkXX8jauKJCUrHj4mSdXVUlStvplPBraekpoBXV1ZKkNGKEpMu2aSOp3Rs3yrIhKup0J6TU1LObkh46JIrjjjsu3p+wpkZSsNu2lY47//3f0qQlIUHacnXuLGv9/fsVcXESXYqNLaNbNyc9emimTw9j506ZfLp00bV1EbIMiIsTf0/37rIHQ2mp+BN27VL8/e9SDam1PCMOBwwbVt8heqU0tGT5YuXAQaEACgsLmT59OosWfUhVVQUQjcs1ni5dXsThSMHplGSe/v0lfJeUJA/FmV1z4OyOOmd22LkQZ/69zwfbtokGT0nRdYUrubkBL+7pv/N6vbRvL8kszRHbv1wa+mAE9jLYs+d0A8uTJ8U8HTZMClk8HklISUs7u1PR+f6PipJahM8/lxBsXp7ksCcny4YnBw9KD8HjxxW33VaG1nGcOiWdjSsqFD6fKPXevcVUPnJEvsDu3cVxuHevWIEFBYqSErEEe/XSZ8kQIPB58vJEASUkyGc7eFBx2202AwaIl76oSLFvnzRXdTo1kZHSCObrr0+RkxNLcbEM8pgYKcaxbVFYCQmBoh+JNm3bJk5OkL4SkZEwerSs9QPyNzVNoQCCYgnQrl07nE4nNTVVhIdH4vNVMXhwHL/9bTJut48jRwJxXVnnHTggvfS7dTvdlTXQZLFTJ8mhvtS66sQJ6fFfXCwPf1gY/PSnp9efso5TPP64n6Rz8huLiytp08ZZ75rXGoGWYufWLkyYINGOG2+0ue8+yRLculX6HLZrd/5rBcK2paUyYA4dgsWLZRm1f78M3rAwKCmR8OuaNZH4/eKwGzAASkttDh0Sy2TvXkmnramRyrmaGimwGjpUk5SkefDB052MA9TUSGv1/fulGMfrFblTUzWPPSZ1ECUl4m8oKpKlXatWssdijx5w991+li61WLVK+iGcOhVDu3ZSSVpZKbP8rbdKrUjAAk1P16xZY5GUJK3aCgoUZWWaxx7TDBzY8uZ9QwgKBQCizR5++BEGD55Ibm42hw8X1cXxA732z+y0evKkpHmWlspD4nDIl1tZKRtwVFSoejPDmTN+q1ayhj0zTq61dLs9eFBmgcY27LhWCQ+X9euKFbIMGD3apm9fzVdfSSfgoUPrNwhRSpRuq1YSoRkwQHPnnXLs
008D3nMxhXfsUJSVRdGpk3jB27aVzsbJyZoOHWSwVVdr9u2TWTY+XpOXJzshRUZSFylQSsKZHTqI36V7d5kAZH8GzdChNs8/7+CzzywiIzVxcZrUVHEe9+ih2bXLYtkyxZIlcOSIpPHK9TQ1NT7Ky8NxuST8N2CAprpa0a2b5tQp2RfA4ZDNYI4eVURHy/IksMXbtULQKID586WmqLi4mH/+55cueX5MDLUeVRnAfr/M2nv3nrZ0wsOhe3dpRnG+dazWMmts2iThRZDEovPFzUORESM0x45Jw5HMTM2IEZIQtHq17CyUlCTO10v1B/je9zTf+56f3bvhjTekYjMtrYqjRyPYtEkG57ffKqKjVW2uh5janTppOnWS7zY1VXPypEQTTpyQ927dWiaC776DRYsU+/ZJPX5NjbxuWbJ8iYuTpLEjR6QCMjpalpHx8TYVFdIcJixMOhnV1EBysmL06BOMHetk7Vqx3ZOSJCktP18yS48eFeuxSxeb4mJF9+7Qo8e1N2EEjQJoLA4Hda28A1RVydIhJ+f0FltwtiWQnCwPdlSUGfTnIymJuhBZdrbFwIF2rWUm0ZlAZCE2VlJ/L7b06toVZs2yqaiw+dvfTrF0aQQJCZoTJyRV1uGQwd2hgzRpLS6Gb78VR2tJSSE7dkymR4+/ERGRQnW1LN0CzsywMFl6VFTIQO7ZUzZlCXzPDoek5K5cKTX8NTWSJt6zpzgY+/TRDBggXaI2brTIz7fIzrbo3Pl05Ck2ViyRggJJ092+XcKJTdXcpSW4bhTA+YiMPP8a13D5DBggZv369YqFCxUulzjqAnH348clUSbQsdjhEJM8La3+dmLR0TB2bAUPPxxHaSm8/ba0vcrPl/bix45Jdlx0tNTwDx1q4/c/Q3n5Slq3/gOPPPIKNTWSaxBY90dEyKDs108ckSdOwK5dsHmz5IycPCnLjzZt4Fe/kpr8zp1FacTFSUbkqlWSONanj6ZTJz9ZWXZtVqoUinXtKgru5ElxWN5337U78ANc1wrA0PQMHCgOLq9X4u9+v8yQPXpIHN6yTi/J8vIUn30mG23A6d544jCMols3yat/8EGbxx+XwfT119Kn8LvvVO0AjmXz5tMh4qVLX2fp0tdRKop+/cqJjxfzvrRUnLrvvKPqti+LipK9+gYOtHE6Vd1+fWVlUqPw3Xfye7duFj//uU18vCw9Nm6URJ6lS626PRmUklDh8OHX13brRgEYrgiX63TH48DGqp98YtXbDw9OD/wzOXmS2g44UiBTWSnXUUpKZ4cN0wwdCidP7mTJkt+Qn/93bFtCxGFh44mIeJH8fLE0wsPF+RgfL8uMiIjTG3NGRMguzR07SnzfsqhN3xXL5MgRiQi88opFdbUImZ6uCQuT4+npUhR1PW19diZGARgajWXJmvvcPRguhoRSG2JCt6GqKo78/CqioqKorq7i4Ydjefnl1mjtY+9eMfO9Xkk2sm2RJzz89Aargc1QfD5x8u3fLynKDocom5EjJapx5szecPmubYwCMAQ9xcXFPPLII0yZMoW5c+fWdcE9vWnp2SFiQ8MxCsAQ9ARCxAAvvXTpELGh4bRgsarBYGhpjAIwGEIYowAMhhDGKACDIYRptAJwu90T3W53k3YDNhgMzUOjFYDH48luCkEMBkPzY5YABkMI06x5AE25N2BLEuwyBrt8EPwyBrt80DQyXlIB1O4AfO72X/lXsglIU+wNGCwEu4zBLh8Ev4zBLh80w96AZo1vMFy/NEUUIEv+c09sAnkMBkMz0hRbgy0DmqjRscFgaE5MFMBgCGGMAjAYQhijAAyGEMYoAIMhhDEKwGAIYYwCMBhCGKMADIYQxigAgyGEMQrAYAhhjAIwGEIYowAMhhDGKACDIYQxCsBgCGGMAjAYQhijAAyGEMYoAIMhhDEKwGAIYYwCMBhCGKMADIYQplE9Ad1utwtIr/032OPx/EdTCGUwGJqHxloAkwB3oHW42+1+tPEiGQyG
5qJRFoDH45lzxq/pgNkk1GC4hmiSrcHcbnc6UHKp3YLM1mDNQ7DLB8EvY7DLB8G1NdhEj8fzk0tdy2wN1nwEu3wQ/DIGu3wQBFuDud3uiR6P58Xanwd6PJ71jZLIYDA0G42NAmQBL7jd7idrXzJRAIPhGqKxTsBlQNcmksVgMDQzJhHIYAhhjAIwGEIYowAMhhDGKACDIYQxCsBgCGGMAjAYQhijAAyGEMYoAIMhhDEKwGAIYYwCMBhCGKMADIYQxigAgyGEMQrAYAhhjAIwGEIYowAMhhDGKACDIYQxCsBgCGGMAjAYQhijAAyGEKbR+wLUNgYFuMNsDWYwXFs0ygJwu90DgYG1zUEH1m4QYjAYrhEa2xV4PbC+dpPQfI/Hk98kUhkMhmahSbYGA9zA7kudNHPmzCZ6O4PB0BQorfVFT2jg1mC43e7Xgc8utZOQwWAIHhq1NZjb7X4B2F27S7CX+orCYDAEMZe0AC5GrdMv4Pj7QUM2CDUYDMFDoxSAwWC4tjGJQAZDCGMUgMEQwjRVGPCKqI0weIH0WkfiZR2/2lzs/WtzHwI+kMEtlQXZ0HvkdrtfaAkZG/AdD6TWj9RSEaRgfw7PkOEnHo/njosc93KZMraYBVArcGCL8TNTiht0vKXlAyYB7sBD63a7H21O+Wrfs0H3qPb1Zs/SbKB8T9bew8SWyCRtwHOYxemwd36twmp2LhGNu+Kx0pJLgMFAIHMwHzj3xl7q+NXmou/v8XjmnKFp0884tzm55D2qHVQtlaF5UflqleY6t9udXns/g/EeeoD3ApZKbfZrsHHFY6UlFYDrnN+TLvP41aZB7187wErOTYxqJlzn/H4+GdNbMEXbdc7v58rXtfa1Erfb/Xrtsqq5Ofc9z5LR4/F4gdeB94BBzSPSZeM65/cGj5WWVABeLp44dKnjV5uGvv/EFsx/8HIRGd1ud1YLKaYAXi59D3fXDrJcoNmXUTTgHgLLPB5PV8AbMLeDDC9XOFZaUgGs47TmSgc+u8zjV5tLvr/b7Z7o8XherP25JdaGl5KxxO12Z9U+tOktIGNDvuMALuRBbm4uJePAM8z+5wnObNcrHistpgBqnRrptRrWdYYD47OLHQ8W+Wpff8Htdue63e5cWuDBaMA9XF/7WiL1zcRgkC8bcAWcVi3hYb+UjMAct9v9aO3xSS0YBciS/05bIE0xVkwmoMEQwphEIIMhhDEKwGAIYYwCMBhCGKMADIYQxigAgyGEMQrAYAhhjAIwGEKY/w8G3jHiWsbDhAAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 288x216 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "with torch.no_grad():\n",
    "    # Initialize plot\n",
    "    f, ax = plt.subplots(1, 1, figsize=(4, 3))\n",
    "    \n",
    "    # Plot training data as black stars\n",
    "    ax.plot(train_x.numpy(), train_y.numpy(), 'k*', zorder=10)\n",
    "    \n",
    "    for i in range(min(num_samples, 25)):\n",
    "        # Plot predictive means as blue line\n",
    "        ax.plot(test_x.numpy(), output.mean[i].detach().numpy(), 'b', linewidth=0.3)\n",
    "        \n",
    "    # Shade between the lower and upper confidence bounds\n",
    "    # ax.fill_between(test_x.numpy(), lower.numpy(), upper.numpy(), alpha=0.5)\n",
    "    ax.set_ylim([-3, 3])\n",
    "    ax.legend(['Observed Data', 'Sampled Means'])"
   ]
  },
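  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The plot above draws individual mean functions. One way to summarize them, sketched below under the assumption that the sampled means are stacked in a tensor of shape `(num_samples, num_test_points)` like `output.mean`, is to average over the sample dimension and take empirical quantiles. The tensor `sampled_means` here is synthetic stand-in data, not output from the model, and the resulting band reflects only the spread of the mean functions across hyperparameter samples, not the full predictive uncertainty.\n",
    "\n",
    "```python\n",
    "import math\n",
    "import torch\n",
    "\n",
    "torch.manual_seed(0)\n",
    "\n",
    "# Synthetic stand-in for output.mean: 100 noisy copies of sin(2*pi*x) on 101 points.\n",
    "grid = torch.linspace(0, 1, 101)\n",
    "sampled_means = torch.sin(2 * math.pi * grid) + 0.1 * torch.randn(100, 101)\n",
    "\n",
    "# Average over the hyperparameter-sample dimension (dim 0).\n",
    "average_mean = sampled_means.mean(dim=0)\n",
    "\n",
    "# Empirical 2.5% and 97.5% quantiles across samples at each test point.\n",
    "lower = torch.quantile(sampled_means, 0.025, dim=0)\n",
    "upper = torch.quantile(sampled_means, 0.975, dim=0)\n",
    "\n",
    "print(average_mean.shape, lower.shape, upper.shape)\n",
    "```\n",
    "\n",
    "With real model output, `lower` and `upper` could be passed to `ax.fill_between` along with a one-dimensional array of test locations to shade the band."
   ]
  },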
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Simulate Loading Model from Disk\n",
    "\n",
    "Loading a fully Bayesian model from disk is slightly different from loading a standard model because the process of sampling changes the shapes of the model's parameters. To account for this, you'll need to call `load_strict_shapes(False)` on the model before loading the state dict. In the cell below, we demonstrate this by recreating the model and loading from the state dict.\n",
    "\n",
    "Note that without the `load_strict_shapes` call, this would fail."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 63,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "state_dict = model.state_dict()\n",
    "model = ExactGPModel(train_x, train_y, likelihood)\n",
    "\n",
    "# Load parameters without standard shape checking.\n",
    "model.load_strict_shapes(False)\n",
    "\n",
    "model.load_state_dict(state_dict)"
   ]
  },
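  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see why shape checking gets in the way, the sketch below reproduces the mismatch with a plain `torch.nn.Module` rather than a GPyTorch model. `TinyModule` and its `raw_value` parameter are hypothetical names used only for illustration; the point is that a freshly constructed module holds one value per parameter, while the saved state holds one value per hyperparameter sample.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "\n",
    "class TinyModule(torch.nn.Module):\n",
    "    # Hypothetical module with a single parameter of configurable shape.\n",
    "    def __init__(self, shape):\n",
    "        super().__init__()\n",
    "        self.raw_value = torch.nn.Parameter(torch.zeros(shape))\n",
    "\n",
    "\n",
    "# State saved from a batched module: one value per sample (100 samples).\n",
    "batched_state = TinyModule((100, 1)).state_dict()\n",
    "\n",
    "# A freshly constructed module has the original, unbatched shape.\n",
    "fresh = TinyModule((1,))\n",
    "\n",
    "try:\n",
    "    fresh.load_state_dict(batched_state)\n",
    "    mismatch_detected = False\n",
    "except RuntimeError:\n",
    "    # torch reports a size mismatch between the saved state and the module.\n",
    "    mismatch_detected = True\n",
    "\n",
    "print(mismatch_detected)\n",
    "```\n",
    "\n",
    "In the GPyTorch cell above, calling `load_strict_shapes(False)` before `load_state_dict` is what allows the sampled, batched parameter shapes to be loaded into the newly created model."
   ]
  },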
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
